Source code for hysop.backend.device.autotunable_kernel

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools, itertools as it

from abc import ABCMeta, abstractmethod
from hysop.tools.htypes import check_instance, first_not_None
from hysop.tools.numpywrappers import npw
from hysop.tools.misc import next_pow2, upper_pow2
from hysop.backend.device.kernel_autotuner_config import KernelAutotunerConfig
from hysop.backend.device.codegen.structs.mesh_info import MeshInfoStruct
from hysop.fields.cartesian_discrete_field import CartesianDiscreteScalarFieldView


class AutotunableKernel(metaclass=ABCMeta):
    """Abstract base class for autotunable device kernels."""

    def __init__(
        self, autotuner_config, build_opts, dump_src=None, symbolic_mode=None, **kwds
    ):
        super().__init__(**kwds)
        self._check_build_configuration(autotuner_config, build_opts)
        self.autotuner_config = autotuner_config
        self.build_opts = build_opts
        self.dump_src = first_not_None(dump_src, autotuner_config.debug)
        self.symbolic_mode = first_not_None(symbolic_mode, autotuner_config.debug)

    def custom_hash(self, *args, **kwds):
        HASH_DEBUG = self.autotuner_config.dump_hash_logs
        assert args or kwds, "no arguments to be hashed."

        def _hash_arg(a):
            s = ""
            if a is None:
                s += "\nNone"
                h = hash("None")
            elif a is Ellipsis:
                s += "\nEllipsis"
                h = hash("Ellipsis")
            elif isinstance(a, str):
                if HASH_DEBUG:
                    s += f"\n>HASHING STR: {a}"
                h = hash(a)
                if HASH_DEBUG:
                    s += f"\n<HASHED STR: hash={h}"
            elif isinstance(a, list):
                if HASH_DEBUG:
                    s += "\n>HASHING LIST:"
                h = hash(tuple(_hash_arg(x) for x in a))
                if HASH_DEBUG:
                    s += f"\n<HASHED LIST: hash={h}"
            elif isinstance(a, tuple):
                if HASH_DEBUG:
                    s += "\n>HASHING TUPLE:"
                h = hash(tuple(_hash_arg(x) for x in a))
                if HASH_DEBUG:
                    s += f"\n<HASHED TUPLE: hash={h}"
            elif isinstance(a, (set, frozenset)):
                if HASH_DEBUG:
                    s += "\n>HASHING SET:"
                h = hash(tuple(_hash_arg(x) for x in sorted(a)))
                if HASH_DEBUG:
                    s += f"\n<HASHED SET: hash={h}"
            elif isinstance(a, dict):
                if HASH_DEBUG:
                    s += "\n>HASHING DICT:"
                h = hash(
                    tuple((_hash_arg(k), _hash_arg(a[k])) for k in sorted(a.keys()))
                )
                if HASH_DEBUG:
                    s += f"\n<HASHED DICT: hash={h}"
            elif isinstance(a, npw.ndarray):
                if HASH_DEBUG:
                    s += "\n>HASHING NDARRAY:"
                assert a.ndim <= 1
                assert a.size < 17, "Only parameters up to size 16 are allowed."
                hh, ss = self.custom_hash(a.tolist())
                h = hh
                s += ss
                if HASH_DEBUG:
                    s += f"\n>HASHED NDARRAY: hash={h}"
            else:
                h = hash(a)
                if HASH_DEBUG:
                    s += f"\n>HASHED UNKNOWN TYPE {type(a)}: hash={h}"
            assert h is not id(a), type(a)
            return h, s

        def _hash_karg(k, v):
            s = ""
            if k == "mesh_info_vars":
                # for mesh infos we just hash the code generated constants that
                # may alter the code branching.
                if HASH_DEBUG:
                    s += "\n<HASHING MESHINFO"
                from hysop.backend.device.codegen.base.variables import CodegenStruct

                check_instance(v, dict, keys=str, values=CodegenStruct)
                mesh_infos = tuple(str(v[k]) for k in sorted(v.keys()))
                h = hash(mesh_infos)
                if HASH_DEBUG:
                    s += "\n MESH INFOS:"
                    for mi in mesh_infos:
                        s += "\n " + mi
                    s += f"\n>HASHED MESHINFO: hash={h}"
                return h, s
            elif k == "expr_info":
                # for expr infos we just hash the continuous and discrete expressions
                # and some additional variables
                if HASH_DEBUG:
                    s += "\n>HASHING EXPR_INFO:"
                exprs = tuple(str(e) for e in v.exprs)
                exprs += tuple(str(e) for e in v.dexprs)
                extras = (v.name, v.direction, v.has_direction, v.dt_coeff, v.kind)
                for k in sorted(
                    v.min_ghosts_per_components.keys(), key=lambda x: x.name
                ):
                    extras += (k.name, _hash_arg(v.min_ghosts_per_components[k]))
                for mem_obj_key in (
                    "input_arrays",
                    "output_arrays",
                    "input_buffers",
                    "output_buffers",
                    "input_params",
                    "output_params",
                ):
                    mem_objects = getattr(v, mem_obj_key)
                    for k in sorted(mem_objects, key=lambda x: x[0]):
                        assert hasattr(mem_objects[k], "short_description"), type(
                            mem_objects[k]
                        ).__mro__
                        extras += (k, hash(mem_objects[k].short_description()))
                hh, ss = self.custom_hash(exprs + extras)
                h = hh
                s += ss
                if HASH_DEBUG:
                    s += "\n EXPRESSIONS:"
                    for e in exprs:
                        s += f"\n {e} {type(e)}"
                        s += f"\n with hash {self.custom_hash(e)[1]}"
                    s += "\n EXTRAS:"
                    for e in extras:
                        s += f"\n {e} {type(e)}"
                        s += f"\n with hash {self.custom_hash(e)[1]}"
                    s += f"\n<HASHED EXPR_INFO: hash={h}"
                return h, s
            else:
                msg = f"Unknown custom hash key '{k}'."
                raise KeyError(msg)

        def hash_all(*args, **kwds):
            h, s = None, None
            if args:
                h, s = _hash_arg(args[0])
                if HASH_DEBUG:
                    s += f"\nHASHED ARGUMENT 0: {h}"
                for i, arg in enumerate(args[1:]):
                    hh, ss = _hash_arg(arg)
                    h ^= hh
                    if HASH_DEBUG:
                        s += ss
                        s += f"\nHASHED ARGUMENT {i}: {h}"
            if kwds:
                items = tuple(sorted(kwds.items(), key=lambda x: x[0]))
                if h is None:
                    h, s = _hash_karg(*items[0])
                else:
                    hh, ss = _hash_karg(*items[0])
                    h ^= hh
                    if HASH_DEBUG:
                        s += ss
                        s += f"\nHASHED KWD 0: {h}"
                for i, it in enumerate(items[1:]):
                    hh, ss = _hash_karg(*it)
                    h ^= hh
                    if HASH_DEBUG:
                        s += ss
                        s += f"\nHASHED KWD {i}: {h}"
            return h, s

        h, s = hash_all(*args, **kwds)
        return h, s

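    # Illustrative usage sketch (comments only, not part of the module):
    # custom_hash folds positional arguments into a single integer cache key
    # and also returns the debug log string built when dump_hash_logs is set.
    # Keyword arguments are restricted to the keys handled by _hash_karg
    # ("mesh_info_vars" and "expr_info"); any other key raises KeyError.
    # Assuming `tuner` is an instance of a concrete subclass:
    #
    #     h, logs = tuner.custom_hash((256, 256), "float32")
    #     # equal inputs always produce equal hashes, so previously
    #     # autotuned kernel configurations can be found again in the cache
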
    @abstractmethod
    def autotune(
        self, name, kernel_args, force_verbose=False, force_debug=False, **extra_kwds
    ):
        """Autotune this kernel with given name and extra_kwds."""
        pass

    @abstractmethod
    def max_device_work_dim(self):
        """Maximum dimensions that specify the global and local work-item IDs."""
        pass

    @abstractmethod
    def max_device_work_group_size(self):
        """Return the maximum number of work items allowed by the device."""
        pass

    @abstractmethod
    def max_device_work_item_sizes(self):
        """
        Maximum number of work-items that can be specified in each dimension
        of the work-group.
        """
        pass

    @abstractmethod
    def compute_args_mapping(self, extra_kwds, extra_parameters):
        """
        Return an arguments mapping, i.e. a dictionary with argument names as
        keys and tuples as values. Each tuple should contain
        (arg_position, arg_type(s)) where arg_position is an int and
        arg_type(s) is a type or a tuple of types to check against.
        """
        pass

    @abstractmethod
    def format_best_candidate(
        self,
        extra_kwds,
        extra_parameters,
        work_load,
        global_work_size,
        local_work_size,
        kernel,
        kernel_statistics,
        src_hash,
        hash_logs,
    ):
        """
        Post-treatment callback for autotuner results.
        Transform autotuner results into user-friendly kernel wrappers.
        """
        pass

    def compute_parameters(self, extra_kwds):
        """Register extra parameters to optimize."""
        return AutotunerParameterConfiguration()

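    # Hedged sketch (not part of the module): a concrete subclass would
    # typically override compute_parameters() to expose its own tuning knobs.
    # The parameter name and candidate values below are purely illustrative:
    #
    #     def compute_parameters(self, extra_kwds):
    #         params = super().compute_parameters(extra_kwds)
    #         params.register_extra_parameter("vectorization", (1, 2, 4, 8))
    #         return params
    #
    # The autotuner then benchmarks every combination yielded by
    # AutotunerParameterConfiguration.iter_parameters().
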
    def compute_work_bounds(
        self,
        max_kernel_work_group_size,
        preferred_work_group_size_multiple,
        extra_parameters,
        extra_kwds,
        work_size=None,
        work_dim=None,
        min_work_load=None,
        max_work_load=None,
    ):
        """
        Configure work bounds (work_dim, work_size, min_work_load, max_work_load).
        Return an AutotunerWorkBoundsConfiguration object.
        """
        check_instance(max_kernel_work_group_size, int)
        check_instance(preferred_work_group_size_multiple, int)
        check_instance(extra_parameters, dict, keys=str)
        check_instance(extra_kwds, dict, keys=str)
        assert max_kernel_work_group_size > 0, max_kernel_work_group_size
        assert (
            preferred_work_group_size_multiple > 0
        ), preferred_work_group_size_multiple

        msg = "FATAL ERROR: Could not extract {} from keyword arguments, "
        msg += "extra_parameters and extra_kwds."
        msg += f"\nFix {type(self)}::compute_work_bounds()."

        work_dim = first_not_None(
            work_dim,
            extra_parameters.get("work_dim", None),
            extra_kwds.get("work_dim", None),
        )
        max_work_dim = self.max_device_work_dim()
        if work_dim is None:
            msg = msg.format("work_dim")
            raise RuntimeError(msg)
        elif work_dim > max_work_dim:
            msg = "Got work_dim {} but maximum supported by device is {}."
            msg = msg.format(work_dim, max_work_dim)
            raise ValueError(msg)

        work_size = first_not_None(
            work_size,
            extra_parameters.get("work_size", None),
            extra_kwds.get("work_size", None),
        )
        if work_size is None:
            msg = msg.format("work_size")
            raise RuntimeError(msg)

        min_work_load = first_not_None(
            min_work_load,
            extra_parameters.get("min_work_load", None),
            extra_kwds.get("min_work_load", None),
            (1,) * work_dim,
        )
        max_work_load = first_not_None(
            max_work_load,
            extra_parameters.get("max_work_load", None),
            extra_kwds.get("max_work_load", None),
            min_work_load,
        )
        assert min_work_load is not None
        assert max_work_load is not None

        max_device_work_dim = self.max_device_work_dim()
        max_device_work_group_size = self.max_device_work_group_size()
        max_device_work_item_sizes = self.max_device_work_item_sizes()
        max_work_group_size = min(
            max_device_work_group_size, max_kernel_work_group_size
        )

        work_bounds = AutotunerWorkBoundsConfiguration(
            work_dim=work_dim,
            work_size=work_size,
            min_work_load=min_work_load,
            max_work_load=max_work_load,
            max_device_work_dim=max_device_work_dim,
            max_device_work_group_size=max_work_group_size,
            max_device_work_item_sizes=max_device_work_item_sizes,
            preferred_work_group_size_multiple=preferred_work_group_size_multiple,
        )
        return work_bounds

    def compute_work_candidates(
        self, work_bounds, work_load, extra_parameters, extra_kwds
    ):
        """
        Configure work (global_size, local_size candidates) given an
        AutotunerWorkBoundsConfiguration instance and a work_load.
        Return an AutotunerWorkConfiguration instance.
        """
        check_instance(work_bounds, AutotunerWorkBoundsConfiguration)
        check_instance(
            work_load, npw.ndarray, dtype=npw.int32, size=work_bounds.work_dim
        )
        check_instance(extra_parameters, dict, keys=str)
        check_instance(extra_kwds, dict, keys=str)

        global_work_size = (work_bounds.work_size + work_load - 1) // work_load
        (min_wg_size, max_wg_size) = self.compute_min_max_wg_size(
            work_bounds=work_bounds,
            work_load=work_load,
            global_work_size=global_work_size,
            extra_parameters=extra_parameters,
            extra_kwds=extra_kwds,
        )

        work = AutotunerWorkConfiguration(
            work_bounds=work_bounds,
            work_load=work_load,
            min_wg_size=min_wg_size,
            max_wg_size=max_wg_size,
        )
        return work

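    # Worked example (illustrative): global_work_size is the elementwise
    # ceil-division of work_size by work_load. With work_size=(1024, 64)
    # and work_load=(4, 3):
    #
    #     global_work_size = ((1024 + 4 - 1) // 4, (64 + 3 - 1) // 3)
    #                      = (256, 22)
    #
    # i.e. each work item handles work_load elements per dimension, and the
    # trailing items may handle a partial load.
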
    def compute_min_max_wg_size(
        self, work_bounds, work_load, global_work_size, extra_parameters, extra_kwds
    ):
        """Default min and max workgroup size."""
        min_wg_size = npw.ones(shape=work_bounds.work_dim, dtype=npw.int32)
        max_wg_size = global_work_size.copy()
        return (min_wg_size, max_wg_size)

    def hash_extra_kwds(self, extra_kwds):
        """Hash the extra_kwds dictionary for caching purposes."""
        for k, v in extra_kwds.items():
            try:
                h = hash(v)
                if h == hash(id(v)):
                    hashable = "hash is id"
                else:
                    hashable = "hashable"
            except Exception:
                hashable = "hash failed"
            print(k, type(v), hashable)
        raise NotImplementedError(f"{type(self).__name__}.hash_extra_kwds()")

    def hash_extra_parameters(self, extra_parameters):
        """Hash the extra_parameters dictionary for caching purposes."""
        for k, v in extra_parameters.items():
            if hash(v) == hash(id(v)):
                msg = "Parameter {} of type {} is not safe to hash."
                msg += (
                    "\nImplement {}.__hash__() so that it depends only on its values "
                )
                msg += "and not its instance id."
                msg = msg.format(k, type(v), str(type(v)))
                raise RuntimeError(msg)
        items = tuple(sorted(extra_parameters.items(), key=lambda x: x[0]))
        return hash(frozenset(items))

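    # Illustrative note (not part of the module): extra parameters are hashed
    # order-independently by hashing the frozenset of their (key, value)
    # items, so two equal parameter dictionaries map to the same cache key:
    #
    #     p0 = {"vectorization": 4, "work_dim": 2}   # hypothetical parameters
    #     p1 = {"work_dim": 2, "vectorization": 4}
    #     # hash(frozenset(p0.items())) == hash(frozenset(p1.items()))
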
    @abstractmethod
    def compute_global_work_size(
        self, work, local_work_size, extra_parameters, extra_kwds
    ):
        """
        Compute aligned global_work_size from unaligned global_work_size
        and local_work_size.
        """
        pass

    @abstractmethod
    def generate_kernel_src(
        self,
        global_work_size,
        local_work_size,
        extra_parameters,
        extra_kwds,
        tuning_mode,
        dry_run,
    ):
        """Generate kernel source code as a string."""
        pass

    @classmethod
    def _check_build_configuration(cls, autotuner_config, build_opts):
        """Check autotuner_config and build options."""
        check_instance(autotuner_config, KernelAutotunerConfig)
        check_instance(build_opts, tuple)

    @classmethod
    def check_cartesian_field(
        cls,
        field,
        dtype=None,
        size=None,
        resolution=None,
        compute_resolution=None,
        nb_components=None,
        ghosts=None,
        min_ghosts=None,
        max_ghosts=None,
        domain=None,
        topology=None,
    ):
        check_instance(field, CartesianDiscreteScalarFieldView)
        if (domain is not None) and (field.domain.domain is not domain):
            msg = "Domain mismatch for dfield {}."
            msg = msg.format(field.name)
            raise RuntimeError(msg)
        if (topology is not None) and (field.topology.topology is not topology):
            msg = "Topology mismatch for dfield {}."
            msg = msg.format(field.name)
            raise RuntimeError(msg)
        if (size is not None) and (field.npoints != size):
            msg = "Size mismatch for dfield {}."
            msg = msg.format(field.name)
            raise RuntimeError(msg)
        if (resolution is not None) and any(field.resolution != resolution):
            msg = "Resolution mismatch for dfield {}."
            msg = msg.format(field.name)
            raise RuntimeError(msg)
        if (compute_resolution is not None) and any(
            field.compute_resolution != compute_resolution
        ):
            msg = "Local resolution mismatch for dfield {}."
            msg = msg.format(field.name)
            raise RuntimeError(msg)
        if (dtype is not None) and (field.dtype != dtype):
            msg = "dtype mismatch for dfield {}."
            msg = msg.format(field.name)
            raise RuntimeError(msg)
        if (nb_components is not None) and (field.nb_components != nb_components):
            msg = "nb_components mismatch for dfield {}."
            msg = msg.format(field.name)
            raise RuntimeError(msg)
        if (ghosts is not None) and npw.any(field.ghosts != ghosts):
            msg = "ghosts mismatch for dfield {}."
            msg = msg.format(field.name)
            raise RuntimeError(msg)
        if (min_ghosts is not None) and npw.any(field.ghosts < min_ghosts):
            msg = "Min ghosts mismatch for dfield {}, expected {} got {}."
            msg = msg.format(field.name, min_ghosts, field.ghosts)
            raise RuntimeError(msg)
        if (max_ghosts is not None) and npw.any(field.ghosts > max_ghosts):
            msg = "Max ghosts mismatch for dfield {}, expected {} got {}."
            msg = msg.format(field.name, max_ghosts, field.ghosts)
            raise RuntimeError(msg)

    @classmethod
    def check_cartesian_fields(cls, *fields, **kwds):
        """
        Check that given fields are compatible (defined on the same domain).
        By default, also compare dtypes, number of components and size.
        Checks can be enabled or disabled by using
        check_[res,cres,size,components,dtype] as boolean keyword arguments.
        """
        check_instance(
            fields, tuple, values=CartesianDiscreteScalarFieldView, minsize=1
        )
        check_resolution = kwds.pop("check_res", False)
        check_compute_resolution = kwds.pop("check_cres", False)
        check_size = kwds.pop("check_size", True)
        check_nb_components = kwds.pop("check_components", True)
        check_dtype = kwds.pop("check_dtype", True)
        assert not kwds, f"Unused keyword arguments {kwds.keys()}."

        domain = fields[0].domain
        resolution = fields[0].resolution
        compute_resolution = fields[0].compute_resolution
        dtype = fields[0].dtype
        size = fields[0].npoints
        nb_components = fields[0].nb_components

        for field in fields:
            if field.domain.domain is not domain.domain:
                msg = "Domain mismatch between dfield {} and dfield {}."
                msg = msg.format(fields[0].name, field.name)
                raise RuntimeError(msg)
            if check_size and (field.npoints != size):
                msg = "Size mismatch between dfield {} and dfield {}."
                msg = msg.format(fields[0].name, field.name)
                raise RuntimeError(msg)
            if check_resolution and any(field.resolution != resolution):
                msg = "Resolution mismatch between dfield {} and dfield {}."
                msg = msg.format(fields[0].name, field.name)
                raise RuntimeError(msg)
            if check_compute_resolution and any(
                field.compute_resolution != compute_resolution
            ):
                msg = "Local resolution mismatch between dfield {} and dfield {}."
                msg = msg.format(fields[0].name, field.name)
                raise RuntimeError(msg)
            if check_dtype and (field.dtype != dtype):
                msg = "dtype mismatch between dfield {} and dfield {}."
                msg = msg.format(fields[0].name, field.name)
                raise RuntimeError(msg)
            if check_nb_components and (field.nb_components != nb_components):
                msg = "nb_components mismatch between dfield {} and dfield {}."
                msg = msg.format(fields[0].name, field.name)
                raise RuntimeError(msg)

    def mesh_info(self, name, mesh):
        """Create a MeshInfoStruct from a CartesianMesh."""
        return MeshInfoStruct.create_from_mesh(
            name=name, mesh=mesh, typegen=self.typegen
        )[1]

    def input_mesh_info(self, field):
        """Create a MeshInfoStruct for an input DiscreteCartesianField."""
        name = f"{field.name}_in_field_mesh_info"
        return self.mesh_info(name=name, mesh=field.mesh.mesh)

    def output_mesh_info(self, field):
        """Create a MeshInfoStruct for an output DiscreteCartesianField."""
        name = f"{field.name}_out_field_mesh_info"
        return self.mesh_info(name=name, mesh=field.mesh.mesh)

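# Hedged overview (not part of the module): a concrete backend implementation
# is expected to provide the abstract device queries and code generation hooks
# above. A kernel autotuner driving such a class would, roughly, proceed as
# follows (the loop structure below is a simplification, not the actual
# autotuner implementation):
#
#     params = tuner.compute_parameters(extra_kwds)
#     for extra_parameters in params.iter_parameters():
#         bounds = tuner.compute_work_bounds(...)
#         for work_load in bounds.iter_work_loads():
#             work = tuner.compute_work_candidates(
#                 bounds, work_load, extra_parameters, extra_kwds)
#             for local_work_size in work.iter_local_work_size():
#                 src = tuner.generate_kernel_src(...)  # then build and benchmark
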
class AutotunerParameterConfiguration:
    """Helper class for kernel autotuning to handle extra parameters."""

    def __init__(self, **kwds):
        super().__init__(**kwds)
        self._param_names = ()
        self._parameters = {}

    def _get_parameter_names(self):
        return self._param_names

    def _get_parameters(self):
        return self._parameters

    param_names = property(_get_parameter_names)
    parameters = property(_get_parameters)

    def register_extra_parameter(self, param_name, candidate_values):
        check_instance(param_name, str)
        if param_name in self._param_names:
            msg = "Parameter {} has already been registered."
            msg = msg.format(param_name)
            raise RuntimeError(msg)
        candidate_values = tuple(candidate_values)
        if len(candidate_values) == 0:
            msg = "candidate_values is empty."
            raise ValueError(msg)
        self._param_names += (param_name,)
        self._parameters[param_name] = candidate_values

    def iter_parameters(self):
        param_names = self._param_names
        param_values = tuple(self._parameters[pname] for pname in param_names)
        param_iterator = it.product(*param_values)
        for params in param_iterator:
            extra_parameters = dict(zip(param_names, params))
            yield extra_parameters

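# Illustrative usage (not part of the module): registered parameters are
# explored as the Cartesian product of their candidate values. The parameter
# names below are hypothetical:
#
#     config = AutotunerParameterConfiguration()
#     config.register_extra_parameter("vectorization", (1, 2, 4))
#     config.register_extra_parameter("use_local_memory", (True, False))
#     for extra_parameters in config.iter_parameters():
#         print(extra_parameters)
#     # -> {'vectorization': 1, 'use_local_memory': True}, then the remaining
#     #    five of the 3 x 2 = 6 combinations
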
class AutotunerWorkBoundsConfiguration:
    """Helper class for kernel autotuning to handle work bounds."""

    def __init__(
        self,
        work_dim,
        work_size,
        min_work_load,
        max_work_load,
        max_device_work_dim,
        max_device_work_group_size,
        max_device_work_item_sizes,
        preferred_work_group_size_multiple,
        **kwds,
    ):
        super().__init__(**kwds)
        assert (
            work_dim <= max_device_work_dim
        ), f"work_dim {work_dim} > {max_device_work_dim}"
        work_dim = int(work_dim)
        assert work_dim > 0
        assert (
            preferred_work_group_size_multiple > 0
        ), preferred_work_group_size_multiple

        work_size = npw.asarray(work_size, dtype=npw.int32)
        min_work_load = npw.asarray(min_work_load, dtype=npw.int32)
        max_work_load = npw.asarray(max_work_load, dtype=npw.int32)

        check_instance(work_size, npw.ndarray, dtype=npw.int32, size=work_dim)
        check_instance(min_work_load, npw.ndarray, dtype=npw.int32, size=work_dim)
        check_instance(max_work_load, npw.ndarray, dtype=npw.int32, size=work_dim)

        assert (work_size > 0).all()
        assert (min_work_load > 0).all()
        assert (max_work_load >= min_work_load).all()

        self._work_dim = work_dim
        self._work_size = work_size
        self._min_work_load = min_work_load
        self._max_work_load = max_work_load
        self._max_device_work_dim = int(max_device_work_dim)
        self._max_device_work_group_size = int(max_device_work_group_size)
        self._max_device_work_item_sizes = npw.asarray(
            max_device_work_item_sizes[:work_dim], dtype=npw.int32
        )
        self._preferred_work_group_size_multiple = preferred_work_group_size_multiple

        self._generate_work_loads()

    def _get_work_dim(self):
        return self._work_dim

    def _get_work_size(self):
        return self._work_size

    def _get_min_work_load(self):
        return self._min_work_load

    def _get_max_work_load(self):
        return self._max_work_load

    def _get_max_device_work_dim(self):
        return self._max_device_work_dim

    def _get_max_device_work_group_size(self):
        return self._max_device_work_group_size

    def _get_max_device_work_item_sizes(self):
        return self._max_device_work_item_sizes

    def _get_preferred_work_group_size_multiple(self):
        return self._preferred_work_group_size_multiple

    work_dim = property(_get_work_dim)
    work_size = property(_get_work_size)
    min_work_load = property(_get_min_work_load)
    max_work_load = property(_get_max_work_load)
    max_device_work_dim = property(_get_max_device_work_dim)
    max_device_work_group_size = property(_get_max_device_work_group_size)
    max_device_work_item_sizes = property(_get_max_device_work_item_sizes)
    preferred_work_group_size_multiple = property(
        _get_preferred_work_group_size_multiple
    )

    def _generate_work_loads(self):
        work_size = self.work_size
        min_work_load, max_work_load = self.min_work_load, self.max_work_load
        min_work_load = npw.minimum(min_work_load, work_size)
        max_work_load = npw.minimum(max_work_load, work_size)

        def _compute_pows(minw, maxw):
            res = []
            wl = minw
            while wl < maxw:
                res.append(wl)
                wl = next_pow2(wl)
            res.append(maxw)
            return tuple(res)

        work_loads = tuple(
            _compute_pows(min_w, max_w)
            for (min_w, max_w) in zip(min_work_load.tolist(), max_work_load.tolist())
        )
        work_loads = it.product(*work_loads)
        self._work_loads = tuple(work_loads)

    def iter_work_loads(self):
        for wl in self._work_loads:
            yield npw.asarray(wl, dtype=npw.int32)

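# Illustrative example (not part of the module): work loads are enumerated per
# dimension from min_work_load up to max_work_load by successive powers of two
# (assuming next_pow2(n) returns the smallest power of two strictly greater
# than n), then combined as a Cartesian product. For instance with
# min_work_load=(1, 1) and max_work_load=(4, 2), iter_work_loads() would yield
# (1, 1), (1, 2), (2, 1), (2, 2), (4, 1), (4, 2).
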
class AutotunerWorkConfiguration:
    """Helper class for kernel autotuning to handle local work size candidates and filters."""

    __debug_filters = False

    def __init__(
        self, work_bounds, work_load, min_wg_size, max_wg_size, ordered_workload=True
    ):
        check_instance(work_bounds, AutotunerWorkBoundsConfiguration)
        check_instance(
            work_load, npw.ndarray, dtype=npw.int32, size=work_bounds.work_dim
        )
        check_instance(
            min_wg_size, npw.ndarray, dtype=npw.int32, size=work_bounds.work_dim
        )
        check_instance(
            max_wg_size, npw.ndarray, dtype=npw.int32, size=work_bounds.work_dim
        )
        assert (min_wg_size >= 1).all(), f"min_wg_size = {min_wg_size}"
        assert (min_wg_size <= max_wg_size).all(), f"{min_wg_size} > {max_wg_size}"

        self._work_bounds = work_bounds
        self._work_load = work_load
        self._global_work_size = (work_bounds.work_size + work_load - 1) // work_load
        self._filters = {}
        self._filter_names = ()
        self._min_wg_size = min_wg_size
        self._max_wg_size = max_wg_size
        self._local_work_size_generator = self._default_work_size_generator

        self._generate_unfiltered_candidates()
        self._load_default_filters(work_bounds, ordered_workload)

    def _get_work_bounds(self):
        return self._work_bounds

    def _get_work_load(self):
        return self._work_load

    def _get_global_work_size(self):
        return self._global_work_size

    def _get_filters(self):
        return self._filters

    def _get_filter_names(self):
        return self._filter_names

    def _get_work_dim(self):
        return self._work_bounds.work_dim

    work_bounds = property(_get_work_bounds)
    work_load = property(_get_work_load)
    work_dim = property(_get_work_dim)
    global_work_size = property(_get_global_work_size)
    filters = property(_get_filters)
    filter_names = property(_get_filter_names)

    def _generate_unfiltered_candidates(self):
        candidates = self._local_work_size_generator()
        check_instance(candidates, tuple, values=npw.ndarray, minsize=1)
        self._unfiltered_candidates = candidates

    def _default_work_size_generator(self):
        """Default local_work_size generator."""
        min_wi_size = self._min_wg_size
        max_wi_size = self._max_wg_size

        def _compute_pows(min_wi, max_wi):
            res = []
            wi = min_wi
            while wi < max_wi:
                res.append(wi)
                wi = next_pow2(wi)
            res.append(max_wi)
            return tuple(res)

        work_items = tuple(
            _compute_pows(min_wi, max_wi)[::-1]
            for (min_wi, max_wi) in zip(min_wi_size.tolist(), max_wi_size.tolist())
        )
        wi_candidates = it.product(*work_items)
        return tuple(npw.asarray(wi, dtype=npw.int32) for wi in wi_candidates)

    def set_local_work_size_generator(self, fn):
        """
        Set a custom local_work_size generator that will generate the set of
        local_work_sizes to be filtered.
        """
        assert callable(fn)
        self._local_work_size_generator = fn
        self._generate_unfiltered_candidates()

    def iter_local_work_size(self):
        """Iterate over filtered local work sizes."""
        candidates = self._unfiltered_candidates
        if self.__debug_filters:
            msg = " *Initial workitems candidates:\n {}\n".format(
                tuple(tuple(x) for x in candidates)
            )
            print(msg)
        for fname in self.filter_names:
            fn = self._filters[fname]
            candidates = tuple(filter(fn, candidates))
            if self.__debug_filters:
                candidates, _ = it.tee(candidates)
                msg = " *Filter {}:\n {}\n".format(fname, tuple(tuple(x) for x in _))
                print(msg)
        return candidates

    def push_filter(self, filter_name, filter_fn, **filter_kwds):
        """Push a named local_work_size filter with custom keywords."""
        check_instance(filter_name, str)
        assert callable(filter_fn)
        if filter_name in self._filter_names:
            msg = "Filter {} has already been registered."
            msg = msg.format(filter_name)
            raise RuntimeError(msg)
        filter_fn = functools.partial(filter_fn, **filter_kwds)
        self._filter_names += (filter_name,)
        self._filters[filter_name] = filter_fn

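    # Illustrative usage (not part of the module): additional constraints can
    # be pushed as named filters; keyword arguments are bound with
    # functools.partial before the filter is applied to each candidate.
    # Assuming `work` is an AutotunerWorkConfiguration instance:
    #
    #     work.push_filter(
    #         "min_wi_sizes (custom)",
    #         work.min_wi_sizes_filter,
    #         min_work_item_sizes=npw.asarray((8, 1), dtype=npw.int32),
    #     )
    #     candidates = work.iter_local_work_size()
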
    def _load_default_filters(self, work_bounds, ordered_workload):
        """Load default local_work_size filters (mostly device limitations)."""
        self.push_filter(
            f"max_device_work_item_sizes (default filter, "
            f"max_work_item_sizes={work_bounds.max_device_work_item_sizes})",
            self.max_wi_sizes_filter,
            max_work_item_sizes=work_bounds.max_device_work_item_sizes,
        )
        self.push_filter(
            f"max_device_work_group_size (default filter, "
            f"max_device_work_group_size={work_bounds.max_device_work_group_size})",
            self.max_wg_size_filter,
            max_work_group_size=work_bounds.max_device_work_group_size,
        )
        if ordered_workload:
            self.push_filter(
                "ordered_workload (default)", self.ordered_workload_filter
            )

    @staticmethod
    def max_wi_sizes_filter(local_work_size, max_work_item_sizes):
        """Filter out work items by size given a maximum size."""
        return (local_work_size <= max_work_item_sizes).all()

    @staticmethod
    def min_wi_sizes_filter(local_work_size, min_work_item_sizes):
        """Filter out work items by size given a minimum size."""
        return (local_work_size >= min_work_item_sizes).all()

    @staticmethod
    def max_wg_size_filter(local_work_size, max_work_group_size):
        """Filter out work items by workgroup size given a maximum workgroup size."""
        return npw.prod(local_work_size, dtype=npw.int64) <= max_work_group_size

    @staticmethod
    def ordered_workload_filter(local_work_size):
        """Filter out work items by decreasing dimensional sizes."""
        oldval = local_work_size[0]
        for val in local_work_size[1:]:
            if val > oldval:
                return False
            oldval = val
        return True

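    # Example behaviour (illustrative): the filter keeps only candidates whose
    # sizes do not increase from one dimension to the next:
    #
    #     AutotunerWorkConfiguration.ordered_workload_filter(
    #         npw.asarray((64, 8, 1)))   # True
    #     AutotunerWorkConfiguration.ordered_workload_filter(
    #         npw.asarray((8, 64, 1)))   # False, since 64 > 8
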
    @abstractmethod
    def make_parameter(self, param):
        pass

    @abstractmethod
    def make_array_offset(self, dim):
        pass

    @abstractmethod
    def make_array_strides(self, dim):
        pass

    @abstractmethod
    def make_array_args(self, **arrays):
        pass

    @abstractmethod
    def make_dt(self, dtype):
        pass